{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[View the runnable example on GitHub](https://github.com/intel-analytics/BigDL/tree/main/python/nano/tutorial/notebook/training/pytorch-lightning/pytorch_lightning_training_bf16.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use BFloat16 Mixed Precision for PyTorch Lightning Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Brain Floating Point Format (BFloat16) is a custom 16-bit floating point format designed for machine learning. BFloat16 consists of 1 sign bit, 8 exponent bits, and 7 mantissa bits. Because it has the same number of exponent bits as FP32, BFloat16 covers the same dynamic range while using only half the memory.\n",
    "\n",
    "BFloat16 mixed precision combines BFloat16 and FP32 during training, which can improve performance and reduce memory usage. Compared to FP16 mixed precision, BFloat16 mixed precision has better numerical stability.\n",
    "\n",
    "The `bigdl.nano.pytorch.Trainer` API extends the PyTorch Lightning Trainer with multiple integrated optimizations. You can instantiate a BigDL-Nano `Trainer` with `precision='bf16'` to use BFloat16 mixed precision for training."
   ]
  },
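  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the range and memory claims above, the following sketch (not part of the original tutorial) uses `torch.finfo` to print the bit width, largest representable value, and machine epsilon of FP32, BFloat16, and FP16. BFloat16 should report a maximum close to that of FP32, while its larger epsilon reflects the 7 mantissa bits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Compare the numeric limits of the three floating point formats.\n",
    "for dtype in (torch.float32, torch.bfloat16, torch.float16):\n",
    "    info = torch.finfo(dtype)\n",
    "    # bits: storage width, max: largest finite value, eps: gap between 1.0 and the next value\n",
    "    print(f'{str(dtype):16} bits={info.bits:2d} max={info.max:.3e} eps={info.eps:.3e}')"
   ]
  },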
  {
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "source": [
    "To use BFloat16 mixed precision in PyTorch Lightning training, you need to install BigDL-Nano for PyTorch first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "!pip install --pre --upgrade bigdl-nano[pytorch] # install the nightly-built version\n",
    "!source bigdl-nano-init # set environment variables"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    ">\n",
    "> Before starting your PyTorch Lightning application, it is highly recommended to run `source bigdl-nano-init` to set several environment variables based on your current hardware. Empirically, these variables bring a significant performance increase for most PyTorch Lightning training workloads."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "source": [
    "> ⚠️ **Warning**\n",
    "> \n",
    "> For Jupyter Notebook users, we recommend running the commands above, especially `source bigdl-nano-init`, before the Jupyter kernel is started; otherwise some of the optimizations may not take effect."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an example, let's take a self-defined `LightningModule` (based on a [ResNet-18 model](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html) pretrained on the ImageNet dataset) and dataloaders to finetune the model on the [OxfordIIITPet dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.OxfordIIITPet.html):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "# Define LightningModule and dataloader\n",
    "\n",
    "import torch\n",
    "from torchvision.models import resnet18\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "class MyLightningModule(pl.LightningModule):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.model = resnet18(pretrained=True)\n",
    "        num_ftrs = self.model.fc.in_features\n",
    "        # here the size of each output sample is set to 37.\n",
    "        self.model.fc = torch.nn.Linear(num_ftrs, 37)\n",
    "        self.criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.model(x)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        output = self.model(x)\n",
    "        loss = self.criterion(output, y)\n",
    "        self.log('train_loss', loss)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        output = self.forward(x)\n",
    "        loss = self.criterion(output, y)\n",
    "        pred = torch.argmax(output, dim=1)\n",
    "        acc = torch.sum(y == pred).item() / (len(y) * 1.0)\n",
    "        metrics = {'test_acc': acc, 'test_loss': loss}\n",
    "        self.log_dict(metrics)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n",
    "\n",
    "\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import OxfordIIITPet\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "def create_dataloaders():\n",
    "    train_transform = transforms.Compose([transforms.Resize(256),\n",
    "                                          transforms.RandomCrop(224),\n",
    "                                          transforms.RandomHorizontalFlip(),\n",
    "                                          transforms.ColorJitter(brightness=.5, hue=.3),\n",
    "                                          transforms.ToTensor(),\n",
    "                                          transforms.Normalize([0.485, 0.456, 0.406],\n",
    "                                                               [0.229, 0.224, 0.225])])\n",
    "    val_transform = transforms.Compose([transforms.Resize(256),\n",
    "                                        transforms.CenterCrop(224),\n",
    "                                        transforms.ToTensor(),\n",
    "                                        transforms.Normalize([0.485, 0.456, 0.406],\n",
    "                                                             [0.229, 0.224, 0.225])])\n",
    "\n",
    "    # apply data augmentation to the train_dataset\n",
    "    train_dataset = OxfordIIITPet(root=\"/tmp/data\", transform=train_transform, download=True)\n",
    "    val_dataset = OxfordIIITPet(root=\"/tmp/data\", transform=val_transform)\n",
    "\n",
    "    # obtain training indices that will be used for validation\n",
    "    indices = torch.randperm(len(train_dataset))\n",
    "    val_size = len(train_dataset) // 4\n",
    "    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])\n",
    "    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])\n",
    "\n",
    "    # prepare data loaders\n",
    "    train_dataloader = DataLoader(train_dataset, batch_size=32)\n",
    "    val_dataloader = DataLoader(val_dataset, batch_size=32)\n",
    "\n",
    "    return train_dataloader, val_dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MyLightningModule()\n",
    "train_loader, val_loader = create_dataloaders()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; _The definition of_ `MyLightningModule` _and_ `create_dataloaders` _can be found in the_ [runnable example](https://github.com/intel-analytics/BigDL/tree/main/python/nano/tutorial/notebook/training/pytorch-lightning/pytorch_lightning_training_bf16.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use BFloat16 mixed precision for your PyTorch Lightning application, simply **import the BigDL-Nano Trainer and set** `precision` **to** `'bf16'`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigdl.nano.pytorch import Trainer\n",
    "\n",
    "trainer = Trainer(max_epochs=5, precision='bf16')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    ">\n",
    "> BFloat16 mixed precision in PyTorch Lightning applications requires `torch>=1.10`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> ⚠️ **Warning**\n",
    "> \n",
    "> Using BFloat16 mixed precision with `torch<1.12` may result in extremely slow training."
   ]
  },
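  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two version requirements above can be checked up front. The cell below is a minimal sketch rather than part of BigDL-Nano: it assumes the third-party `packaging` library is available in your environment, and the printed messages are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from packaging.version import Version  # assumption: `packaging` is installed\n",
    "\n",
    "# Strip any local build suffix such as '+cpu' before comparing.\n",
    "torch_version = Version(torch.__version__.split('+')[0])\n",
    "\n",
    "if torch_version < Version('1.10'):\n",
    "    print('BFloat16 mixed precision requires torch>=1.10.')\n",
    "elif torch_version < Version('1.12'):\n",
    "    print('BFloat16 mixed precision is available, but training may be extremely slow.')\n",
    "else:\n",
    "    print('BFloat16 mixed precision is available.')"
   ]
  },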
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also set `use_ipex=True` and `precision='bf16'` at the same time to enable IPEX ([Intel® Extension for PyTorch*](https://github.com/intel/intel-extension-for-pytorch)) optimizer fusion for BFloat16 mixed precision training, which provides further acceleration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigdl.nano.pytorch import Trainer\n",
    "\n",
    "trainer = Trainer(max_epochs=5, use_ipex=True, precision='bf16')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    ">\n",
    "> `Trainer(..., use_ipex=True, precision='bf16')` is designed to disable PyTorch native Auto Mixed Precision (AMP) and enable AMP from IPEX instead."
   ]
  },
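  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the cell below sketches what PyTorch native AMP looks like in plain PyTorch on CPU, using `torch.autocast` (available in `torch>=1.10`). It only illustrates the general mechanism: the layer and input are made up for this example, and the internals of the BigDL-Nano `Trainer` may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# A toy layer and input, used only for this illustration.\n",
    "layer = torch.nn.Linear(8, 4)\n",
    "x = torch.randn(2, 8)\n",
    "\n",
    "# Inside the autocast region, eligible ops such as linear run in BFloat16.\n",
    "with torch.autocast(device_type='cpu', dtype=torch.bfloat16):\n",
    "    y = layer(x)\n",
    "\n",
    "print(layer.weight.dtype)  # parameters are still stored in FP32\n",
    "print(y.dtype)             # the output of the autocast op is BFloat16"
   ]
  },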
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can then run BFloat16 mixed precision training and evaluation as usual:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.fit(model, train_dataloaders=train_loader)\n",
    "trainer.validate(model, dataloaders=val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📚 **Related Readings**\n",
    "> \n",
    "> - [How to install BigDL-Nano](https://bigdl.readthedocs.io/en/latest/doc/Nano/Overview/install.html)\n",
    "> - [How to accelerate a PyTorch Lightning application on training workloads through Intel® Extension for PyTorch*](https://bigdl.readthedocs.io/en/latest/doc/Nano/Howto/Training/PyTorchLightning/accelerate_pytorch_lightning_training_ipex.html)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.13 ('nano-pytorch': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.7.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "09344c7f3239fd422839751f876786d6b1a624c40f19af1b43cb2737f421c2b2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
